"""
Phase 5: Equity Markets
========================
Z → stock market valuations and equity positions.
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Root of this sub-project. NOTE(review): hard-coded absolute path — the
# script only runs where this location exists.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/asset_returns")
# The sibling "multilateral" project's src/ directory is put first on
# sys.path before importing PanelGLS from `model` (presumably that is where
# the module lives — confirm).
MULTILATERAL_DIR = PROJECT_DIR.parent / "multilateral"
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS

# Directory holding the input panel, and directory receiving generated tables.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# ISO3 country codes (38 entries) used by main() to select the OECD subsample
# via the panel's `iso3` column.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# Column names of the control regressors. main() keeps only those that are
# actually present in the panel; build_table() lists all of them in the
# footnote of the output table.
CONTROLS = ['rgdp_growth', 'inflation', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']


def stars(p):
    """Return the significance marker for a p-value.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.10 and
    an empty string otherwise.
    """
    # Thresholds are checked from strictest to loosest; first match wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label):
    """Fit ``PanelGLS`` for one specification and collect its estimates.

    Args:
        df: Panel with ``iso3`` and ``year`` columns plus the model variables.
        dep_var: Column name of the dependent variable.
        regressors: Requested regressor column names. Names that are not
            columns of ``df`` are dropped, and the drop is reported on stdout
            so the printed label can be checked against what was estimated.
        label: Short model label used in console output and in the result.

    Returns:
        A dict with ``label``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared``, ``rho`` and, for every regressor ``x`` actually used,
        ``coef_x``, ``se_x`` and ``p_x``. ``None`` when the model is skipped:
        the dependent variable is missing, no requested regressor is
        available, or fewer than 50 complete rows remain.
    """
    if dep_var not in df.columns:
        print(f"  [{label}] {dep_var} missing — skipping")
        return None

    requested = list(regressors)
    regressors = [r for r in requested if r in df.columns]
    dropped = [r for r in requested if r not in df.columns]
    if dropped:
        # Make the silent specification change visible to the reader.
        print(f"  [{label}] Regressors not in panel (dropped): "
              f"{', '.join(map(str, dropped))}")
    if not regressors:
        # Nothing to estimate; avoid handing an empty design matrix to the fit.
        print(f"  [{label}] No regressors available — skipping")
        return None

    # Keep only rows where the dependent variable and all regressors are observed.
    sub = df.dropna(subset=[dep_var] + regressors).copy()
    if len(sub) < 50:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[regressors].values,
            sub['iso3'].values, sub['year'].values)

    print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, "
          f"R²={gls.r_squared:.4f}")

    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared, 'rho': gls.rho,
    }
    # Estimates are read positionally, in the order the regressors were passed.
    for i, name in enumerate(regressors):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<25} {gls.beta[i]:>8.4f} ({gls.se[i]:.4f}) {sig}")

    return results


def main():
    """Run the Phase 5 equity regressions and write the summary table.

    Loads ``asset_panel.csv`` from ``PROCESSED_DIR``, prints coverage of the
    two equity variables, runs each model specification through
    ``run_model`` in a fixed order, and passes the non-skipped results to
    ``build_table``.
    """
    banner = "=" * 70
    print(banner)
    print("PHASE 5: Equity Markets")
    print(banner)

    panel = pd.read_csv(PROCESSED_DIR / "asset_panel.csv")
    print(f"Panel: {panel['iso3'].nunique()} countries, {len(panel)} obs")

    # Coverage report for the two equity variables (skipped if absent).
    for var in ('stock_market_cap_gdp', 'port_eq_assets_gdp'):
        if var not in panel.columns:
            continue
        observed = panel[var].notna()
        n_c = panel.loc[observed, 'iso3'].nunique()
        n_o = observed.sum()
        print(f"  {var}: {n_c} countries, {n_o} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in CONTROLS if c in panel.columns]
    baseline = demo_vars + controls
    int_vars = [v for v in ('Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen')
                if v in panel.columns]
    oecd = panel[panel['iso3'].isin(OECD_38)].copy()

    # Specifications as (data, dependent variable, regressors, label), listed
    # in the order they are estimated and reported.
    specs = [
        # Model 1: stock_market_cap_gdp = Z + controls
        (panel, 'stock_market_cap_gdp', baseline,
         "M1: Z → stock mkt cap/GDP"),
        # Model 2: port_eq_assets_gdp = Z + controls
        (panel, 'port_eq_assets_gdp', baseline,
         "M2: Z → portfolio equity/GDP"),
    ]
    if int_vars:
        # Model 3: Z × KAOPEN, only when interaction columns exist.
        specs.append((panel, 'stock_market_cap_gdp', baseline + int_vars,
                      "M3: Z×KAOPEN → stock mkt cap"))
    specs.extend([
        # Model 4: old_dep + youth_dep
        (panel, 'stock_market_cap_gdp', ['old_dep', 'youth_dep'] + controls,
         "M4: age ratios → stock mkt cap"),
        # Model 5: changes
        (panel, 'd_stock_market_cap', baseline,
         "M5: Z → Δstock mkt cap"),
        # Model 6: OECD subsample
        (oecd, 'stock_market_cap_gdp', baseline,
         "M6: OECD Z → stock mkt cap"),
        # Model 7: Kopecky-Taylor asymmetry — compare the Z effect on equity
        # vs the safe rate (within OECD).
        (oecd, 'stock_market_cap_gdp', baseline,
         "M7a: OECD Z → equity (KT)"),
        (oecd, 'real_bond_10y', baseline,
         "M7b: OECD Z → safe rate (KT)"),
    ])

    all_results = []
    for frame, dep_var, regressors, label in specs:
        fitted = run_model(frame, dep_var, regressors, label)
        if fitted:
            all_results.append(fitted)

    # Build table
    build_table(all_results)

    print("\n" + banner)
    print("Phase 5 complete.")
    print(banner)


def build_table(all_results):
    """Write a Markdown summary of the fitted models to ``equity.md``.

    The file is created under ``TABLES_DIR`` and contains an overview table
    (one row per model) followed by a table of the key coefficients each
    model reports.

    Args:
        all_results: List of result dicts as returned by ``run_model``. When
            empty, nothing is written.
    """
    if not all_results:
        return

    # Only these variables appear in the coefficient table; a model
    # contributes a row for each one it has a ``coef_`` entry for.
    key_vars = ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep',
                'Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']

    md = ["# Equity Market Results\n"]
    md.append("| Model | Dep Var | N | Countries | R² |")
    md.append("|---|---|---|---|---|")
    for r in all_results:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']} "
                  f"| {r['n_countries']} | {r['r_squared']:.3f} |")

    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in all_results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")

    md.append(f"\n*Controls: {', '.join(CONTROLS)}*")

    out = TABLES_DIR / "equity.md"
    # The output directory may not exist yet (e.g. on a fresh checkout).
    out.parent.mkdir(parents=True, exist_ok=True)
    # The table contains non-ASCII characters (R², →, ×, Δ); pin the encoding
    # so the write does not depend on the platform's locale default.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


# Run the analysis only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
